import datetime
from sqlalchemy.engine import engine_from_config
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, scoped_session, sessionmaker
from sqlalchemy.schema import Column, ForeignKey
from sqlalchemy.types import DateTime, Integer, Unicode, Numeric, Date
from stockforecaster.lib import ystockquote
from stockforecaster.lib.pagination import Pager
from zope.sqlalchemy import ZopeTransactionExtension


# Earliest day requested for a company that has no stored quotes yet; once
# quotes exist, fetching resumes from the newest stored quote instead.
START_DATE = datetime.date(year=2010, month=1, day=1)

# Session registry shared by every model below. Nothing in this module calls
# commit(); with ZopeTransactionExtension, committing is presumably driven by
# the Zope `transaction` manager in the application — confirm with app setup.
DBSession = scoped_session(
    sessionmaker(extension=ZopeTransactionExtension())
)


def format_request_date(dt):
    """Render *dt* (a date or datetime) as a compact ``YYYYMMDD`` string."""
    return format(dt, '%Y%m%d')


def format_response_date(dt):
    """Parse a ``YYYY-MM-DD`` string from a quote response.

    Returns a naive ``datetime.datetime`` at midnight of that day; raises
    ``ValueError`` if the string does not match the format.
    """
    response_format = '%Y-%m-%d'
    parsed = datetime.datetime.strptime(dt, response_format)
    return parsed


def initialize_sqla(settings):
    """Build an engine from the ``sqlalchemy.*`` keys of *settings* and bind it.

    The scoped session and the declarative metadata are both bound to the
    same engine.
    """
    bound_engine = engine_from_config(settings, prefix='sqlalchemy.')
    DBSession.configure(bind=bound_engine)
    Base.metadata.bind = bound_engine


def initialize_db(settings):
    """Bind the engine described by *settings*, then create the model tables."""
    initialize_sqla(settings)
    # create_all() uses the engine bound to the metadata by initialize_sqla().
    metadata = Base.metadata
    metadata.create_all()


class Base(object):
    """Mixin supplying shared helpers to every declarative model.

    ``query`` is a query property on the scoped session, so ``Model.query``
    starts a query for that model class.
    """

    query = DBSession.query_property()

    def __unicode__(self):
        # Fallback label; models override this with something descriptive.
        return type(self).__name__

    def __str__(self):
        text = unicode(self)
        return text.encode('utf-8')

    def __repr__(self):
        text = unicode(self)
        return repr(text)

    @classmethod
    def get_by_id(cls, id):
        """Return the instance with primary key *id* (``Query.get``)."""
        return cls.query.get(id)

    @classmethod
    def create(cls, **kwargs):
        """Construct an instance, add it to the session, flush, and return it.

        Only a flush is issued here, not a commit.
        """
        instance = cls(**kwargs)
        session = DBSession
        session.add(instance)
        session.flush()
        return instance

    def delete(self):
        """Mark this instance for deletion in the current session."""
        DBSession.delete(self)


# Rebind ``Base`` from the helper mixin above to the real declarative base.
# The models below subclass it and so inherit the mixin's query property and
# helper methods.
Base = declarative_base(cls=Base)


class Company(Base):
    """A listed company whose daily stock quotes are stored locally."""

    __tablename__ = 'companies'
    id = Column(Integer(unsigned=True), primary_key=True)
    symbol = Column(Unicode(10), unique=True)
    name = Column(Unicode(40), unique=True)
    exchange = Column(Unicode(20), index=True)
    added_at = Column(DateTime, default=datetime.datetime.utcnow)

    def __unicode__(self):
        return self.symbol

    @property
    def current_price(self):
        """Price from ``ystockquote.get_price``, fetched on every access.

        Presumably a live network call — avoid in tight loops.
        """
        return ystockquote.get_price(self.symbol)

    @property
    def chart(self):
        """Chart from ``ystockquote.get_chart``, fetched on every access."""
        return ystockquote.get_chart(self.symbol)

    @property
    def yahoo_finance_url(self):
        """URL of the Yahoo! Finance quote page for this symbol."""
        return 'http://finance.yahoo.com/q?s=%s' % self.symbol

    @property
    def yahoo_finance_feed_url(self):
        """URL of the Yahoo! Finance RSS headline feed (US region) for this symbol."""
        return 'http://feeds.finance.yahoo.com/rss/2.0/headline?s=%s&region=US' % self.symbol

    @classmethod
    def get_by_symbol(cls, symbol):
        """Return the first company whose symbol equals ``symbol.upper()``, or None."""
        return cls.query.filter(cls.symbol == symbol.upper()).first()

    @classmethod
    def get_companies(cls, limit=50, page=1):
        """Return one page of companies, ordered alphabetically by symbol.

        Paging arithmetic is delegated to ``Pager(page, limit)``.
        """
        pager = Pager(page, limit)
        return (cls.query
                .order_by(cls.symbol.asc())
                .limit(pager.limit)
                .offset(pager.offset)
                .all())

    @classmethod
    def count_companies(cls):
        """Return the total number of companies."""
        return cls.query.count()

    def get_all_stock_quotes(self):
        """Return every stored quote for this company, oldest first."""
        return (self.stock_quotes
                .order_by(StockQuote.date.asc())
                .all())

    def get_stock_quotes(self, limit=50, page=1):
        """Return one page of this company's quotes, newest first."""
        pager = Pager(page, limit)
        return (self.stock_quotes
                .order_by(StockQuote.date.desc())
                .limit(pager.limit)
                .offset(pager.offset)
                .all())

    def count_stock_quotes(self):
        """Return how many quotes are stored for this company."""
        return self.stock_quotes.count()

    def fetch_stock_quotes(self):
        """Download historical quotes and store the ones not yet saved.

        The first fetch starts at START_DATE. Later fetches resume at the date
        of the newest stored quote; that day is requested again but skipped by
        the existence check.
        """
        start_date = START_DATE
        end_date = datetime.datetime.utcnow()
        last_stock_quote = self._get_last_stock_quote()
        if last_stock_quote:
            start_date = last_stock_quote.date
        data = ystockquote.get_historical_prices(
            self.symbol, format_request_date(start_date), format_request_date(end_date))
        for row in data:
            quote = self._quote_from_row(row)
            if not StockQuote.exists(self, quote['date']):
                StockQuote.create(**quote)

    def _quote_from_row(self, row):
        """Map one historical-price row onto StockQuote constructor kwargs."""
        row = dict(row)
        return {'company': self,
                'open_price': row['Open'],
                'high_price': row['High'],
                'low_price': row['Low'],
                'close_price': row['Close'],
                # NOTE(review): 'Adj Clos' (not 'Adj Close') presumably matches
                # the key ystockquote produces — confirm before "fixing" it.
                'adjusted_close_price': row['Adj Clos'],
                'volume': row['Volume'],
                # StockQuote.date is a Date column; drop the midnight time part
                # that format_response_date() attaches.
                'date': format_response_date(row['Date']).date()}

    def get_previous_company(self):
        """Return the company immediately before this one by symbol, or None."""
        return (Company.query
                .order_by(Company.symbol.desc())
                .filter(Company.symbol < self.symbol)
                .first())

    def get_next_company(self):
        """Return the company immediately after this one by symbol, or None."""
        return (Company.query
                .order_by(Company.symbol.asc())
                .filter(Company.symbol > self.symbol)
                .first())

    def _get_last_stock_quote(self):
        """Return this company's most recent stored quote, or None."""
        return (self.stock_quotes
                .order_by(StockQuote.date.desc())
                .first())

class StockQuote(Base):
    """One trading day's prices and volume for a Company."""

    __tablename__ = 'stock_quotes'
    id = Column(Integer(unsigned=True), primary_key=True)
    company_id = Column(Integer(unsigned=True), ForeignKey('companies.id', ondelete='cascade'))
    open_price = Column(Numeric(8, 2), nullable=False)
    high_price = Column(Numeric(8, 2), nullable=False)
    low_price = Column(Numeric(8, 2), nullable=False)
    close_price = Column(Numeric(8, 2), nullable=False)
    adjusted_close_price = Column(Numeric(8, 2), nullable=False)
    volume = Column(Integer(unsigned=True), nullable=False)
    date = Column(Date, nullable=False)

    # The dynamic backref makes Company.stock_quotes a query object, which is
    # why Company calls .order_by()/.count() on it directly.
    company = relationship(
        'Company', backref=backref('stock_quotes', lazy='dynamic'), lazy='joined')

    def __unicode__(self):
        # Base.__str__/__repr__ pass this result through unicode(), which
        # rejects the raw Numeric (Decimal) value; return text instead.
        return unicode(self.close_price)

    @classmethod
    def exists(cls, company, date):
        """Return the stored quote for *company* on *date*, or None.

        Callers only test the result for truthiness.
        """
        return (cls.query
                .filter(cls.company == company)
                .filter(cls.date == date)
                .first())
